{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 人脸检测和识别训练流程\n",
    "\n",
    "以下示例展示了如何在自己的数据集上微调InceptionResnetV1模型。这将主要遵循标准的PyTorch训练模式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, SubsetRandomSampler\n",
    "from torch import optim\n",
    "from torch.optim.lr_scheduler import MultiStepLR\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "from torchvision import datasets, transforms\n",
    "import numpy as np\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义运行参数\n",
    "\n",
    "数据集应该遵循VGGFace2/ImageNet风格的目录布局。将`data_dir`修改为您要微调的数据集所在的位置。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the dataset to fine-tune on (VGGFace2/ImageNet-style layout).\n",
    "data_dir = '../data/test_images'\n",
    "\n",
    "# Batch size and number of training epochs.\n",
    "batch_size, epochs = 32, 8\n",
    "\n",
    "# Number of DataLoader workers: none on Windows (os.name == 'nt'), 8 elsewhere.\n",
    "if os.name == 'nt':\n",
    "    workers = 0\n",
    "else:\n",
    "    workers = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 判断是否有NVIDIA GPU可用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda:0')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "\n",
    "print('在该设备上运行: {}'.format(device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义MTCNN模块\n",
    "\n",
    "查看`help(MTCNN)`获取更多细节。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Face detector configuration; run help(MTCNN) for the meaning of each option.\n",
    "mtcnn = MTCNN(\n",
    "    image_size=160,\n",
    "    margin=0,\n",
    "    min_face_size=20,\n",
    "    thresholds=[0.6, 0.7, 0.7],\n",
    "    factor=0.709,\n",
    "    post_process=True,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 执行MTCNN人脸检测\n",
    "\n",
    "迭代DataLoader对象并获取裁剪后的人脸。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the raw images, resizing each one to 512x512 before detection.\n",
    "dataset = datasets.ImageFolder(data_dir, transform=transforms.Resize((512, 512)))\n",
    "\n",
    "# Swap every sample's label for the path its cropped face will be saved to.\n",
    "# Only the first occurrence of data_dir (expected to be the leading root of the\n",
    "# path) is rewritten, so a class folder or file name that happens to contain\n",
    "# the same text is left untouched.\n",
    "dataset.samples = [\n",
    "    (p, p.replace(data_dir, data_dir + '_cropped', 1))\n",
    "    for p, _ in dataset.samples\n",
    "]\n",
    "\n",
    "loader = DataLoader(\n",
    "    dataset,\n",
    "    num_workers=workers,\n",
    "    batch_size=batch_size,\n",
    "    collate_fn=training.collate_pil\n",
    ")\n",
    "\n",
    "# y holds the destination paths built above; mtcnn is asked to save its crops there.\n",
    "for i, (x, y) in enumerate(loader):\n",
    "    mtcnn(x, save_path=y)\n",
    "    print('\\r第 {} 批，共 {} 批'.format(i + 1, len(loader)), end='')\n",
    "\n",
    "# Remove mtcnn to reduce GPU memory usage\n",
    "del mtcnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义Inception Resnet V1模块\n",
    "\n",
    "查看`help(InceptionResnetV1)`获取更多细节。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classification model with one class per entry in dataset.class_to_idx,\n",
    "# starting from the 'vggface2' pretrained weights.\n",
    "resnet = InceptionResnetV1(\n",
    "    classify=True, pretrained='vggface2', num_classes=len(dataset.class_to_idx)\n",
    ")\n",
    "\n",
    "# Move the model onto the selected device.\n",
    "resnet = resnet.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义优化器、调度器、数据集和数据加载器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(resnet.parameters(), lr=0.001)\n",
    "# MultiStepLR milestones: the learning rate changes after 5 and 10 scheduler steps.\n",
    "scheduler = MultiStepLR(optimizer, [5, 10])\n",
    "\n",
    "# Convert each cropped image to float32, then to a tensor, then standardize it.\n",
    "trans = transforms.Compose([\n",
    "    np.float32,\n",
    "    transforms.ToTensor(),\n",
    "    fixed_image_standardization\n",
    "])\n",
    "dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)\n",
    "\n",
    "# Shuffle the sample indices with a fixed seed so the 80/20 train/validation\n",
    "# split is the same on every run (keeps Restart & Run All reproducible).\n",
    "split_seed = 0\n",
    "img_inds = np.random.default_rng(split_seed).permutation(len(dataset))\n",
    "n_train = int(0.8 * len(img_inds))\n",
    "train_inds = img_inds[:n_train]\n",
    "val_inds = img_inds[n_train:]\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    dataset,\n",
    "    num_workers=workers,\n",
    "    batch_size=batch_size,\n",
    "    sampler=SubsetRandomSampler(train_inds)\n",
    ")\n",
    "val_loader = DataLoader(\n",
    "    dataset,\n",
    "    num_workers=workers,\n",
    "    batch_size=batch_size,\n",
    "    sampler=SubsetRandomSampler(val_inds)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 定义损失和评估函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classification loss used for fine-tuning.\n",
    "loss_fn = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "# Per-batch metrics, keyed by the name they are reported under.\n",
    "metrics = dict(\n",
    "    fps=training.BatchTimer(),\n",
    "    acc=training.accuracy,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "writer = SummaryWriter()\n",
    "writer.iteration, writer.interval = 0, 10\n",
    "\n",
    "# Close the TensorBoard writer even when a pass raises or the run is interrupted.\n",
    "try:\n",
    "    print('\\n\\n初始化')\n",
    "    print('-' * 10)\n",
    "\n",
    "    # Baseline validation before any fine-tuning. No optimizer is passed to the\n",
    "    # validation passes, so run them under no_grad to skip autograd bookkeeping.\n",
    "    resnet.eval()\n",
    "    with torch.no_grad():\n",
    "        training.pass_epoch(\n",
    "            resnet, loss_fn, val_loader,\n",
    "            batch_metrics=metrics, show_running=True, device=device,\n",
    "            writer=writer\n",
    "        )\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        print('\\n循环 {}/{}'.format(epoch + 1, epochs))\n",
    "        print('-' * 10)\n",
    "\n",
    "        # Training pass: the optimizer and scheduler are handed to pass_epoch.\n",
    "        resnet.train()\n",
    "        training.pass_epoch(\n",
    "            resnet, loss_fn, train_loader, optimizer, scheduler,\n",
    "            batch_metrics=metrics, show_running=True, device=device,\n",
    "            writer=writer\n",
    "        )\n",
    "\n",
    "        # Validation pass on the held-out indices, with gradients disabled.\n",
    "        resnet.eval()\n",
    "        with torch.no_grad():\n",
    "            training.pass_epoch(\n",
    "                resnet, loss_fn, val_loader,\n",
    "                batch_metrics=metrics, show_running=True, device=device,\n",
    "                writer=writer\n",
    "            )\n",
    "finally:\n",
    "    writer.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
